{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "91b23aa5-bf39-4d01-8840-6d95e16829ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/daiyuxin/anaconda3/envs/lvyufeng/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "from mindspore import nn\n",
    "from mindspore import ops\n",
    "from mindspore.common.initializer import initializer\n",
    "\n",
    "from mindnlp import load_dataset, process, Vocab\n",
    "from mindnlp._legacy.abc import Seq2vecModel\n",
    "from mindnlp.engine import Trainer\n",
    "from mindnlp.metrics import Accuracy\n",
    "from mindnlp.modules import Glove, RNNEncoder, StaticLSTM\n",
    "from mindnlp.transforms import BasicTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e27f0fa8-c881-42ae-b162-fa46d8c724a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['text', 'label']\n"
     ]
    }
   ],
   "source": [
    "# load datasets\n",
    "imdb_train, imdb_test = load_dataset('imdb', shuffle=True)\n",
    "print(imdb_train.get_col_names())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3de8db32-1c2a-4b14-b3c4-aa7e7773d113",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BasicTokenizer(True)\n",
    "vocab = Vocab.from_pretrained(name=\"glove.6B.100d\")\n",
    "\n",
    "imdb_train = process('imdb', imdb_train, tokenizer=tokenizer, vocab=vocab, \\\n",
    "                     bucket_boundaries=[400, 500], max_len=600, drop_remainder=True)\n",
    "imdb_test = process('imdb', imdb_test, tokenizer=tokenizer, vocab=vocab, \\\n",
    "                     bucket_boundaries=[400, 500], max_len=600, drop_remainder=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b703055d-a044-4053-b067-53a6bcbbfeb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SentimentClassification(Seq2vecModel):\n",
    "    def construct(self, text):\n",
    "        _, (hidden, _), _ = self.encoder(text)\n",
    "        context = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)\n",
    "        output = self.head(context)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f2d9dd6d-bc88-4be0-a452-9bceff227087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "hidden_size = 256\n",
    "output_size = 2\n",
    "num_layers = 2\n",
    "bidirectional = True\n",
    "dropout = 0.5\n",
    "lr = 5e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "db621c85-dcf1-47fe-bf21-a472d66d2aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build encoder\n",
    "embedding = Glove.from_pretrained('6B', 100, special_tokens=[\"<unk>\", \"<pad>\"])\n",
    "lstm_layer = StaticLSTM(100, hidden_size, num_layers=num_layers, batch_first=True,\n",
    "                     dropout=dropout, bidirectional=bidirectional)\n",
    "encoder = RNNEncoder(embedding, lstm_layer)\n",
    "\n",
    "# build head\n",
    "head = nn.Sequential([\n",
    "    nn.Dropout(p=dropout),\n",
    "    nn.Dense(hidden_size * 2, output_size)\n",
    "])\n",
    "\n",
    "# build network\n",
    "network = SentimentClassification(encoder, head)\n",
    "loss = nn.CrossEntropyLoss()\n",
    "optimizer = nn.Adam(network.trainable_params(), learning_rate=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4eb3b5d5-3ab7-48bf-8afa-3c104a78da1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_weights(m):\n",
    "    if isinstance(m, nn.Dense):\n",
    "        m.weight.set_data(initializer('xavier_normal', m.weight.shape, m.weight.dtype))\n",
    "        m.bias.set_data(initializer('zeros', m.bias.shape, m.bias.dtype))\n",
    "    elif isinstance(m, StaticLSTM):\n",
    "        for name, param in m.parameters_and_names():\n",
    "            if 'bias' in name:\n",
    "                param.set_data(initializer('zeros', param.shape, param.dtype))\n",
    "            elif 'weight' in name:\n",
    "                param.set_data(initializer('orthogonal', param.shape, param.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9da68e43-9dd8-4106-ab0f-45f3224ccbd8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SentimentClassification<\n",
       "  (encoder): RNNEncoder<\n",
       "    (embedding): Glove<\n",
       "      (dropout_layer): Dropout<p=0.0>\n",
       "      >\n",
       "    (rnn): StaticLSTM<\n",
       "      (rnn): MultiLayerRNN<\n",
       "        (cell_list): CellList<\n",
       "          (0): SingleLSTMLayer_GPU<>\n",
       "          (1): SingleLSTMLayer_GPU<>\n",
       "          >\n",
       "        (dropout): Dropout<p=0.5>\n",
       "        >\n",
       "      >\n",
       "    >\n",
       "  (head): Sequential<\n",
       "    (0): Dropout<p=0.5>\n",
       "    (1): Dense<input_channels=512, output_channels=2, has_bias=True>\n",
       "    >\n",
       "  >"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network.apply(initialize_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d39bcaf4-1a4b-41bf-8562-4ebea52434bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|█████████████████████████████████████████████████████████████████████| 390/390 [01:16<00:00,  5.12it/s, loss=0.648818]\n",
      "Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:16<00:00, 24.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.54512}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|████████████████████████████████████████████████████████████████████| 390/390 [01:12<00:00,  5.37it/s, loss=0.6403419]\n",
      "Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:14<00:00, 27.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.76188}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2:   0%|                                                                                              | 0/390 [00:00<?, ?it/s][WARNING] PYNATIVE(592184,7f8c217f0740,python):2023-03-22-23:56:50.349.509 [mindspore/ccsrc/pipeline/pynative/grad/grad.cc:1252] CheckAlreadyRun] The input info of this cell has changed, forward process will run again\n",
      "Epoch 2: 100%|████████████████████████████████████████████████████████████████████| 390/390 [01:13<00:00,  5.32it/s, loss=0.5273511]\n",
      "Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:15<00:00, 25.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.82848}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3:   0%|                                                                                              | 0/390 [00:00<?, ?it/s][WARNING] PYNATIVE(592184,7f8c217f0740,python):2023-03-22-23:58:18.780.065 [mindspore/ccsrc/pipeline/pynative/grad/grad.cc:1252] CheckAlreadyRun] The input info of this cell has changed, forward process will run again\n",
      "Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 390/390 [01:13<00:00,  5.29it/s, loss=0.36224687]\n",
      "Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:14<00:00, 27.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.85832}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4:   0%|                                                                                              | 0/390 [00:00<?, ?it/s][WARNING] PYNATIVE(592184,7f8c217f0740,python):2023-03-22-23:59:46.799.922 [mindspore/ccsrc/pipeline/pynative/grad/grad.cc:1252] CheckAlreadyRun] The input info of this cell has changed, forward process will run again\n",
      "Epoch 4: 100%|███████████████████████████████████████████████████████████████████| 390/390 [01:13<00:00,  5.30it/s, loss=0.27380344]\n",
      "Evaluate: 100%|███████████████████████████████████████████████████████████████████████████████████| 392/392 [00:14<00:00, 26.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.86688}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# define metrics\n",
    "metric = Accuracy()\n",
    "\n",
    "# define trainer\n",
    "trainer = Trainer(network=network, train_dataset=imdb_train, eval_dataset=imdb_test, metrics=metric,\n",
    "                  epochs=5, loss_fn=loss, optimizer=optimizer)\n",
    "trainer.run(tgt_columns=\"label\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
